﻿using System;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlClient;

namespace QQ2564874169.RelationalSql
{
    /// <summary>
    /// SQL Server implementation of <c>Sql</c> built on System.Data.SqlClient.
    /// Owns a single <see cref="SqlConnection"/> and at most one active <see cref="SqlTransaction"/>.
    /// Outside a transaction the connection is opened in <c>OnBefore</c> and closed again in
    /// <c>OnAfter</c>; while a transaction is active it stays open until commit/rollback/dispose.
    /// </summary>
    /// <remarks>
    /// NOTE(review): assumes the base class brackets every On* operation with OnBefore/OnAfter
    /// and derives <c>InTransaction</c> from the Begin/Commit/Rollback calls — confirm against <c>Sql</c>.
    /// </remarks>
    public class SqlserverSql : Sql
    {
        private SqlConnection _conn;
        private SqlTransaction _transaction;

        /// <summary>
        /// Builds a SQL-authentication connection string.
        /// </summary>
        /// <param name="host">Server name, optionally with a port as <c>server:port</c>.</param>
        /// <param name="database">Initial catalog.</param>
        /// <param name="user">SQL login name.</param>
        /// <param name="pwd">SQL login password.</param>
        /// <exception cref="ArgumentNullException"><paramref name="host"/> is null.</exception>
        /// <exception cref="FormatException">The port part of <paramref name="host"/> is malformed.</exception>
        private static string CreateConnectString(string host, string database, string user, string pwd)
        {
            if (host == null)
            {
                throw new ArgumentNullException(nameof(host));
            }

            var dataSource = host;
            if (host.Contains(":"))
            {
                var items = host.Split(':');
                int port;
                if (items.Length != 2 || !int.TryParse(items[1], out port) || port < 1 || port > 65535)
                {
                    throw new FormatException($"host格式不正确：{host}。");
                }
                // SqlClient expects "server,port" rather than "server:port".
                dataSource = items[0] + "," + port;
            }

            // The builder escapes/quotes values, so credentials containing ';', '=' or quotes
            // cannot corrupt the connection string or inject extra keywords.
            // Its setters reject null, so nulls become "" (what the old interpolation produced).
            var builder = new SqlConnectionStringBuilder
            {
                DataSource = dataSource,
                InitialCatalog = database ?? string.Empty,
                IntegratedSecurity = false,
                UserID = user ?? string.Empty,
                Password = pwd ?? string.Empty
            };
            return builder.ConnectionString;
        }

        /// <summary>
        /// Creates an instance that logs in to <paramref name="host"/> with SQL authentication.
        /// </summary>
        public SqlserverSql(string host, string database, string user, string pwd)
            : this(CreateConnectString(host, database, user, pwd))
        {
        }

        /// <summary>
        /// Creates an instance from a complete SqlClient connection string.
        /// The connection is not opened until the first operation.
        /// </summary>
        public SqlserverSql(string connectString)
        {
            _conn = new SqlConnection(connectString);
        }

        /// <summary>
        /// Returns the connection, failing clearly if this instance has already been disposed.
        /// </summary>
        /// <exception cref="ObjectDisposedException">OnDispose has already run.</exception>
        private SqlConnection RequireConnection()
        {
            if (_conn == null)
            {
                throw new ObjectDisposedException(GetType().Name);
            }
            return _conn;
        }

        /// <summary>
        /// Opens the connection if it is not already open.
        /// </summary>
        private void EnsureOpen()
        {
            var conn = RequireConnection();
            if (conn.State != ConnectionState.Open)
            {
                conn.Open();
            }
        }

        /// <summary>
        /// Builds a command for <paramref name="context"/>.
        /// It is enlisted in the current transaction (if any) and carries the context's timeout
        /// and parameters. The caller owns, and must dispose, the returned command.
        /// </summary>
        private SqlCommand CreateCommand(SqlContext context)
        {
            var cmd = RequireConnection().CreateCommand();
            try
            {
                cmd.CommandText = context.Command;
                cmd.CommandType = context.IsProc ? CommandType.StoredProcedure : CommandType.Text;
                if (_transaction != null)
                {
                    cmd.Transaction = _transaction;
                }
                if (context.Timeout.HasValue)
                {
                    cmd.CommandTimeout = context.Timeout.Value;
                }
                // Any sequence of SqlParamValue is accepted (collections still match as before).
                var param = context.Param as IEnumerable<SqlParamValue>;
                if (param != null)
                {
                    foreach (var item in param)
                    {
                        cmd.Parameters.Add(ParamConvert(item));
                    }
                }
                return cmd;
            }
            catch
            {
                // Don't leak the half-built command if parameter conversion fails.
                cmd.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Releases the connection. A transaction still pending at this point is rolled back.
        /// The transaction and connection are disposed even if that rollback throws.
        /// </summary>
        protected override void OnDispose()
        {
            try
            {
                if (_transaction != null)
                {
                    _transaction.Rollback();
                }
            }
            finally
            {
                if (_transaction != null)
                {
                    _transaction.Dispose();
                    _transaction = null;
                }
                if (_conn != null)
                {
                    // SqlConnection.Dispose also closes the connection.
                    _conn.Dispose();
                    _conn = null;
                }
            }
        }

        /// <summary>
        /// Opens the connection (if needed) before an operation runs.
        /// </summary>
        protected override SqlBeforeEventArgs OnBefore(SqlContext context)
        {
            EnsureOpen();
            return base.OnBefore(context);
        }

        /// <summary>
        /// Closes the connection after an operation, unless a transaction is keeping it open.
        /// </summary>
        protected override void OnAfter(SqlBeforeEventArgs before)
        {
            if (!InTransaction && _conn != null)
            {
                _conn.Close();
            }
            base.OnAfter(before);
        }

        /// <summary>
        /// Executes a non-query command and returns the number of affected rows.
        /// </summary>
        protected override int OnExecute(SqlContext context)
        {
            using (var cmd = CreateCommand(context))
            {
                return cmd.ExecuteNonQuery();
            }
        }

        /// <summary>
        /// Executes the command and returns the first column of the first row as <typeparamref name="T"/>.
        /// No row (null) or SQL NULL (DBNull) yields <c>default(T)</c>. Values of a different
        /// convertible type (e.g. an <c>int</c> COUNT read as <c>long</c>) are converted.
        /// </summary>
        protected override T OnQueryScalar<T>(SqlContext context)
        {
            object value;
            using (var cmd = CreateCommand(context))
            {
                value = cmd.ExecuteScalar();
            }
            return ConvertScalar<T>(value);
        }

        /// <summary>
        /// Converts a raw ExecuteScalar result to <typeparamref name="T"/>.
        /// Also handles null/DBNull, Nullable&lt;T&gt; and enum targets.
        /// </summary>
        /// <exception cref="InvalidCastException">The value cannot be converted to <typeparamref name="T"/>.</exception>
        private static T ConvertScalar<T>(object value)
        {
            if (value == null || value is DBNull)
            {
                return default(T);
            }
            if (value is T)
            {
                return (T) value;
            }
            var target = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T);
            if (target.IsEnum)
            {
                // Unboxing an integral value into an enum worked with the old direct cast; keep it working.
                return (T) Enum.ToObject(target, value);
            }
            return (T) Convert.ChangeType(value, target, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Executes the command and materializes every row as <typeparamref name="T"/>.
        /// </summary>
        protected override IEnumerable<T> OnQuery<T>(SqlContext context)
        {
            using (var cmd = CreateCommand(context))
            using (var reader = cmd.ExecuteReader())
            {
                return reader.ToArray<T>();
            }
        }

        /// <summary>
        /// Executes the command and wraps the open reader for multi-result-set consumption.
        /// </summary>
        protected override IMultipleReader OnQueryMultiple(SqlContext context)
        {
            // The command is intentionally not disposed here: the returned reader outlives this method.
            // NOTE(review): outside a transaction, OnAfter closes the connection. If the base class
            // calls OnAfter before the caller finishes reading, this reader is invalidated — confirm.
            var cmd = CreateCommand(context);

            return new MultipleReader(cmd.ExecuteReader());
        }

        /// <summary>
        /// Opens the connection (if needed) and starts a transaction. No-op if one is already active.
        /// </summary>
        protected override void OnBeginTransaction()
        {
            if (_transaction != null)
            {
                return;
            }
            EnsureOpen();
            _transaction = _conn.BeginTransaction();
        }

        /// <summary>
        /// Commits the active transaction (if any) and closes the connection.
        /// If Commit throws, the transaction is kept so the caller can still roll it back.
        /// </summary>
        protected override void OnCommitTransaction()
        {
            if (_transaction == null)
            {
                return;
            }
            _transaction.Commit();
            ReleaseTransaction();
        }

        /// <summary>
        /// Rolls back the active transaction (if any) and closes the connection.
        /// Cleanup also happens when Rollback throws: the transaction is unusable at that point.
        /// </summary>
        protected override void OnRollbackTransaction()
        {
            if (_transaction == null)
            {
                return;
            }
            try
            {
                _transaction.Rollback();
            }
            finally
            {
                ReleaseTransaction();
            }
        }

        /// <summary>
        /// Disposes and forgets the current (non-null) transaction.
        /// Then closes the connection that was held open for it.
        /// </summary>
        private void ReleaseTransaction()
        {
            _transaction.Dispose();
            _transaction = null;
            if (_conn != null)
            {
                _conn.Close();
            }
        }

        /// <summary>
        /// Maps a provider-neutral <c>SqlParamValue</c> to a <see cref="SqlParameter"/>.
        /// <c>Type</c> may be a <see cref="SqlDbType"/> or a <see cref="DbType"/>; anything else is ignored.
        /// </summary>
        private static SqlParameter ParamConvert(SqlParamValue item)
        {
            // SqlClient treats a null Value as "parameter not supplied".
            // A null value is therefore sent as DBNull so it means SQL NULL.
            var p = new SqlParameter(item.Name, (object) item.Value ?? DBNull.Value);
            if (item.Size.HasValue)
            {
                p.Size = item.Size.Value;
            }
            if (item.Type is SqlDbType)
            {
                p.SqlDbType = (SqlDbType) item.Type;
            }
            else if (item.Type is DbType)
            {
                p.DbType = (DbType) item.Type;
            }
            return p;
        }
    }
}
